Conditional Random Field (CRF)

您所在的位置:网站首页 pytorch-crf assert mask[0]all Conditional Random Field (CRF)

Conditional Random Field (CRF)

2023-10-15 23:48| 来源: 网络整理| 查看: 265

Conditional Random Field (CRF)

Note

Copyright pytorch-crf from https://github.com/kmkurn/pytorch-crf.

Licensed under the MIT License.

CRF

This module implements a Conditional Random Field (CRF). The forward computation of this class computes the log likelihood of the given sequence of tags and emission score tensor. This class also has a CRF.decode() method, which finds the best tag sequence given an emission score tensor using the Viterbi algorithm.

Attributes:

num_tags: An integer indicating the number of tags to be predicted.

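batch_first: A boolean indicating whether the first dimension of the input tensors is the batch dimension. If False, inputs are laid out as (seq_length, batch_size, ...); if True, as (batch_size, seq_length, ...).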

start_transitions: An nn.Parameter vector of size (num_tags,) containing the score for starting a sequence with each tag.

end_transitions: An nn.Parameter vector of size (num_tags,) containing the score for ending a sequence with each tag.

transitions: An nn.Parameter matrix of size (num_tags, num_tags) in which entry (i, j) is the score of transitioning from tag i to tag j.

class CRF(nn.Module):
    """Conditional Random Field (CRF) module.

    This module implements a Conditional Random Field (CRF). The forward
    computation of this class computes the log likelihood of the given sequence
    of tags and emission score tensor. This class also has a `CRF.decode()`
    method which finds the best tag sequence given an emission score tensor
    using the Viterbi algorithm.

    Attributes:
        num_tags (`int`):
            Number of tags to be predicted.
        batch_first (`bool`):
            Whether the first dimension of the inputs is the batch dimension.
            If `False`, inputs are `(seq_length, batch_size, ...)`; if `True`,
            they are `(batch_size, seq_length, ...)` and are transposed
            internally.
        start_transitions (`nn.Parameter`):
            Start transition scores, a vector of size `(num_tags,)`.
        end_transitions (`nn.Parameter`):
            End transition scores, a vector of size `(num_tags,)`.
        transitions (`nn.Parameter`):
            Transition scores of size `(num_tags, num_tags)`; entry `(i, j)` is
            the score of moving from tag `i` to tag `j`.
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        """Constructs a `CRF`.

        Args:
            num_tags: Number of tags; must be positive.
            batch_first: Whether inputs have the batch dimension first.

        Raises:
            ValueError: If `num_tags` is not positive.
        """
        # NOTE(review): the text between `if num_tags` and the next `None:` was
        # lost when this file was extracted; this constructor is reconstructed
        # from the documented attributes and `reset_parameters` below. Confirm
        # the exact error message against the upstream pytorch-crf source.
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.

        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def __repr__(self) -> str:
        """Displays the class name and the number of tags."""
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(self, emissions: torch.Tensor,
                tags: torch.LongTensor,
                mask: Optional[torch.ByteTensor] = None,
                reduction: str = 'sum') -> torch.Tensor:
        """Compute the conditional log likelihood of a sequence of tags given emission scores.

        Args:
            emissions: Emission scores of size `(seq_length, batch_size, num_tags)`,
                or `(batch_size, seq_length, num_tags)` if `batch_first` is set.
            tags: Tag indices of size `(seq_length, batch_size)`, or
                `(batch_size, seq_length)` if `batch_first` is set.
            mask: Optional mask with the same shape as `tags`; non-zero marks a
                valid timestep. Both `uint8` and `bool` masks are accepted.
                Defaults to all timesteps valid.
            reduction: One of `'none'`, `'sum'`, `'mean'`, `'token_mean'`.

        Returns:
            The log likelihood: a `(batch_size,)` tensor for `'none'`, otherwise
            a scalar (summed, averaged over the batch, or averaged over the
            valid tokens).

        Raises:
            ValueError: If the inputs fail validation or `reduction` is unknown.
        """
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.bool)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = numerator - denominator

        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        assert reduction == 'token_mean'
        return llh.sum() / mask.type_as(emissions).sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using Viterbi algorithm.

        Args:
            emissions: Emission scores of size `(seq_length, batch_size, num_tags)`,
                or `(batch_size, seq_length, num_tags)` if `batch_first` is set.
            mask: Optional mask over the first two dimensions of `emissions`;
                `uint8` and `bool` are both accepted. Defaults to all valid.

        Returns:
            One list of tag indices per sample; each list is as long as the
            number of valid timesteps of that sample.

        Raises:
            ValueError: If the inputs fail validation.
        """
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.bool)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

    def _validate(self, emissions: torch.Tensor,
                  tags: Optional[torch.LongTensor] = None,
                  mask: Optional[torch.ByteTensor] = None) -> None:
        """Validates the emission dimension and whether its slice satisfies tag number, tag shape and mask shape.

        Raises:
            ValueError: If `emissions` is not 3-dimensional, its last dimension
                differs from `num_tags`, its first two dimensions do not match
                `tags` or `mask`, or the mask is off anywhere in the first
                timestep.
        """
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            # The first timestep lives on a different axis depending on layout.
            first_step = mask[:, 0] if self.batch_first else mask[0]
            if not first_step.all():
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(self, emissions: torch.Tensor,
                       tags: torch.LongTensor,
                       mask: torch.ByteTensor) -> torch.Tensor:
        """Computes the score based on the emission and transition matrix.

        Returns:
            A `(batch_size,)` tensor holding the unnormalized score of `tags`.
        """
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.type_as(emissions)
        # Built once and reused for every per-sample gather below.
        batch_idx = torch.arange(batch_size, device=tags.device)

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]] + emissions[0, batch_idx, tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score = score + self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score = score + emissions[i, batch_idx, tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, batch_idx]
        # shape: (batch_size,)
        return score + self.end_transitions[last_tags]

    def _compute_normalizer(self, emissions: torch.Tensor,
                            mask: torch.ByteTensor) -> torch.Tensor:
        """Compute the log-sum-exp score.

        Returns:
            A `(batch_size,)` tensor holding the log partition function, i.e.
            the log-sum-exp of the scores of all possible tag sequences.
        """
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        # torch.where requires a boolean condition; uint8 masks are converted.
        mask = mask.bool()
        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score = score + self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        """Decodes the optimal path using Viterbi algorithm.

        Returns:
            One list of tag indices per sample, covering only that sample's
            valid timesteps.
        """
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        # torch.where requires a boolean condition; uint8 masks are converted.
        mask = mask.bool()
        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score = score + self.end_transitions

        # Now, compute the best path for each sample.
        # Converted to Python ints once so the list slicing below does not
        # index with a 0-d tensor for every sample.
        seq_ends = (mask.long().sum(dim=0) - 1).tolist()
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3